{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fbddb14f",
   "metadata": {},
   "source": [
    "## 自定义调用本地大模型方法\n",
    "1. **类属性定义**:\n",
    "   - `max_token`: 定义了模型可以处理的最大令牌数。\n",
    "   - `do_sample`: 指定是否在生成文本时采用采样策略。\n",
    "   - `temperature`: 控制生成文本的随机性，较高的值会产生更多随机性。\n",
    "   - `top_p`: 一种替代`temperature`的采样策略，这里设置为0.0，意味着不使用。\n",
    "   - `tokenizer`: 分词器，用于将文本转换为模型可以理解的令牌。\n",
    "   - `model`: 存储加载的模型对象。\n",
    "   - `history`: 存储对话历史。\n",
    "2. **构造函数**:\n",
    "   - `__init__`: 构造函数初始化了父类的属性。\n",
    "3. **属性方法**:\n",
    "   - `_llm_type`: 返回模型的类型，即`ChatGLM3`。\n",
    "4. **加载模型的方法**:\n",
    "   - `load_model`: 此方法用于加载模型和分词器。它首先尝试从指定的路径加载分词器，然后加载模型，并将模型设置为评估模式。这里的模型和分词器是从Hugging Face的`transformers`库中加载的。\n",
    "5. **调用方法**:\n",
    "   - `_call`: 一个内部方法，用于调用模型。它被设计为可以被子类覆盖。\n",
    "   - `invoke`: 这个方法使用模型进行聊天。它接受一个提示和一个历史记录，并返回模型的回复和更新后的历史记录。这里使用了模型的方法`chat`来生成回复，并设置了采样、最大长度和温度等参数。\n",
    "6. **流式方法**:\n",
    "   - `stream`: 这个方法允许模型逐步返回回复，而不是一次性返回所有内容。这对于长回复或者需要实时显示回复的场景很有用。它通过模型的方法`stream_chat`实现，并逐块返回回复。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "c400fed2",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.llms.base import LLM\n",
    "from transformers import AutoTokenizer,AutoModel,AutoConfig\n",
    "from langchain_core.messages.ai import AIMessage\n",
    "\n",
    "class ChatGLM3(LLM):\n",
    "    max_token:int=8192\n",
    "    do_sample:bool = True\n",
    "    temperature:float = 0.3\n",
    "    top_p = 0.0\n",
    "    tokenizer:object = None\n",
    "    model:object = None\n",
    "    history = []\n",
    "    \n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "    \n",
    "    @property\n",
    "    def _llm_type(self):\n",
    "        return \"ChatGLM3\"\n",
    "    \n",
    "    def load_model(self,modelPath=None):\n",
    "        #modelPath = \"D:\\\\ai\\\\download\\\\chatglm3\"\n",
    "        #modelPath = \"I:\\wd0717\\wendaMain\\wenda\\model\\chatglm3-6b\"\n",
    "        #配置分词器\n",
    "        tokenizer = AutoTokenizer.from_pretrained(modelPath,trust_remote_code=True,use_fast=True)\n",
    "\n",
    "        #加载模型\n",
    "        model = AutoModel.from_pretrained(modelPath,trust_remote_code=True,device_map=\"auto\")\n",
    "\n",
    "        model = model.eval()\n",
    "        \n",
    "        self.model = model\n",
    "        self.tokenizer = tokenizer\n",
    "        \n",
    "    def _call(self,prompt,config={},history=[]):\n",
    "        return self.invoke(prompt,history)\n",
    "    \n",
    "    def invoke(self,prompt,config={},history=[]):\n",
    "        if not isinstance(prompt, str):\n",
    "            prompt = prompt.to_string()\n",
    "        response,history = self.model.chat(\n",
    "            self.tokenizer,\n",
    "            prompt,\n",
    "            history=history,\n",
    "            do_sample=self.do_sample,\n",
    "            max_length=self.max_token,\n",
    "            temperature=self.temperature\n",
    "        )\n",
    "        self.history = history\n",
    "        return AIMessage(content=response)\n",
    "        \n",
    "    def stream(self,prompt,config={},history=[]):\n",
    "        if not isinstance(prompt, str):\n",
    "            prompt = prompt.to_string()\n",
    "        preResponse = \"\"\n",
    "        for response,new_history in self.model.stream_chat(self.tokenizer,prompt):\n",
    "            #self.history = new_history\n",
    "            \n",
    "            if preResponse == \"\":\n",
    "                result = response\n",
    "            else:\n",
    "                result = response[len(preResponse):]\n",
    "            preResponse = response\n",
    "            yield result\n",
    "        "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "4208e574",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "587c05dd590946e4a4ac08981716cdae",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:root:Some parameters are on the meta device device because they were offloaded to the cpu.\n"
     ]
    }
   ],
   "source": [
    "llm = ChatGLM3()\n",
    "modelPath = \"I:\\wd0717\\wendaMain\\wenda\\model\\chatglm3-6b\"\n",
    "llm.load_model(modelPath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "46df3eae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='中国的首都是北京。')"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#调用call方法\n",
    "llm.invoke(\"中国的首都是？\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "7d13d7f2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "春节将至喜洋洋，\n",
      "鞭炮声声迎新春。\n",
      "家家户户张灯结彩，\n",
      "熙熙攘攘贺新年。\n",
      "\n",
      "走亲访友送礼物，\n",
      "亲朋好友欢聚一堂。\n",
      "笑声不断祝福送，\n",
      "幸福安康乐无疆。\n",
      "\n",
      "春节到来春意浓，\n",
      "万物复苏生机勃。\n",
      "祝福祖国繁荣昌盛，\n",
      "人民幸福安康。"
     ]
    }
   ],
   "source": [
    "for response in llm.stream(\"写一首诗春节的诗\"):\n",
    "    print(response,end=\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "00195179",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "prompt = ChatPromptTemplate.from_template(\"请根据下面的主题写一篇小红书营销的短文： {topic}\")\n",
    "output_parser = StrOutputParser()\n",
    "\n",
    "chain = prompt | llm | output_parser\n",
    "\n",
    "# chain.invoke({\"topic\": \"康师傅绿茶\"})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "63e8d43e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "【康师傅绿茶，让你尽享清新时光】\n",
      "\n",
      "夏日的阳光，炽热而耀眼，人的心情也随着温度上升而变得烦躁。这时，一瓶康师傅绿茶出现在你的面前，仿佛一股清流涌入你的生活。\n",
      "\n",
      "康师傅绿茶，自诞生之初就立志于为人们带来清新、舒适的时光。它采用严格筛选的绿茶原料，通过传统工艺制作而成，每一口都散发着清新的茶香，让你在炎炎夏日中找到一丝凉意。\n",
      "\n",
      "而且，康师傅绿茶不仅口感清新，还具有丰富的营养价值。绿茶富含多种维生素和矿物质，可以帮助消除疲劳，提高免疫力，让你在面对炎热的夏天时，保持活力充沛。\n",
      "\n",
      "与此同时，康师傅绿茶还致力于环保，采用可回收的包装设计，让你在享受清新的茶时光的同时，也为地球做出贡献。\n",
      "\n",
      "在这个夏天，让我们一起品味康师傅绿茶，让清新时光伴你左右。让我们一起享受绿茶带来的清新时光，让生活更加美好。"
     ]
    }
   ],
   "source": [
    "for chunk in chain.stream({\"topic\": \"康师傅绿茶\"}):\n",
    "    print(chunk,end=\"\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56455360",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be507681",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9ca72829",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e431df0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
